import re
from typing import Any, Dict, List

from evalscope.api.benchmark import BenchmarkMeta, DefaultDataAdapter
from evalscope.api.dataset import Sample
from evalscope.api.evaluator import TaskState
from evalscope.api.messages import ChatMessageUser, ContentText
from evalscope.api.metric.scorer import AggScore, SampleScore, Score
from evalscope.api.registry import register_benchmark
from evalscope.constants import Tags
from evalscope.utils.logger import get_logger
from evalscope.utils.multi_choices import parse_answers, prompt

# Module-level logger; used below to warn when no answer can be extracted
# from a model prediction.
logger = get_logger()

# Human-readable summary of the benchmark, passed (stripped) as the
# ``description`` of the BenchmarkMeta registration further down.
DESCRIPTION = (
    'Drivelology, a unique linguistic phenomenon characterised as "nonsense with depth" - '
    'utterances that are syntactically coherent yet pragmatically paradoxical, emotionally loaded, '
    'or rhetorically subversive.'
)

# Prompt template for the multi-label task. It has three placeholders --
# {letters}, {question} and {choices} -- and is handed to ``prompt(...)`` in
# ``record_to_sample``. The trailing ``.strip()`` removes the blank first/last
# lines that the triple-quoted literal would otherwise carry.
#
# NOTE(review): the text is sent to the model as-is. It contains small wording
# slips ("to creating a new", "[LETTER]S" vs "[LETTERS]"); they are left
# untouched here because editing the prompt can shift benchmark scores --
# confirm with the benchmark owners before correcting them.
MULTIPLE_ANSWER_TEMPLATE = r"""
#Instruction#:
Classify the given text into one or more of the following categories: inversion, wordplay, switchbait, paradox, and misdirection.

#Definitions#:
- inversion: This technique takes a well-known phrase, cliché, or social script and flips it on its head. The humour arises by reversing a familiar structure to creating a new, often satirical, meaning.
- wordplay: This is the use of linguistic creativity, often by exploiting the phonetics or polysemy of words. It includes puns, double entendres, and similarities.
- switchbait: This technique hinges on a specific phrase (the "bait") that has a culturally-embedded double meaning. The initial context is then suddenly replaced (the "switch") by a surprising second meaning. The humour is generated by this cynical or culturally-specific reinterpretation of the bait, rather than by derailing a narrative.
- paradox: This relies on a statement that appears logically self-contradictory but contains a latent, often humorous or profound truth. The core of the technique is the clash of seemingly incompatible ideas.
- misdirection: This technique leads the listener down an expected path before a final twist reveals a different, often more literal or absurd, ending.

Answer the following multiple choice question where multiple answers may be correct.
The entire content of your response should be of the following format: 'ANSWER: [LETTERS]' (without quotes) where [LETTER]S is one or more of {letters}.

{question}

{choices}
""".strip()  # noqa: E501


# Register this adapter under the benchmark name 'drivel_multilabel'.
# The four names in ``metric_list`` are the same metric names that
# ``aggregate_scores`` emits as ``AggScore.metric_name``.
# NOTE(review): ``record_to_sample`` renders the complete prompt from
# MULTIPLE_ANSWER_TEMPLATE on its own, so ``prompt_template='{question}'``
# looks like a pass-through placeholder -- confirm how the base adapter uses it.
@register_benchmark(
    BenchmarkMeta(
        name='drivel_multilabel',
        pretty_name='DrivelologyMultilabelClassification',
        tags=[Tags.MULTIPLE_CHOICE],
        description=DESCRIPTION.strip(),
        dataset_id='extraordinarylab/drivel-hub',
        subset_list=['multi-label-classification'],
        metric_list=['f1_weighted', 'f1_micro', 'f1_macro', 'exact_match'],
        aggregation='f1_weighted',
        eval_split='test',
        prompt_template='{question}',
    )
)
class DrivelologyMultilabelClassificationAdapter(DefaultDataAdapter):

    def __init__(self, *args, **kwargs):
        """
        Initialise the adapter and its letter <-> category lookup tables.

        All positional and keyword arguments are forwarded unchanged to the
        parent initialiser.
        """
        super().__init__(*args, **kwargs)

        # Category names, listed in the order in which they are lettered A-E.
        labels = ['inversion', 'wordplay', 'switchbait', 'paradox', 'misdirection']
        self.categories = labels
        # Option letter -> category name.
        self.choices = dict(zip('ABCDE', labels))
        # Category name -> option letter (inverse of ``self.choices``).
        self.categories_to_letters = {name: letter for letter, name in self.choices.items()}

    def record_to_sample(self, record: Dict[str, Any]) -> Sample:
        """
        Turn one dataset record into a multiple-choice ``Sample``.

        Reads ``record['text']`` and ``record['label']``. The user prompt is
        rendered from MULTIPLE_ANSWER_TEMPLATE through ``prompt``; the target is
        the sorted string of option letters for every label that has an entry in
        ``self.categories_to_letters`` (labels without one are left out).

        Args:
            record: A raw dataset row with ``text`` and ``label`` fields

        Returns:
            Sample whose input is a single user message, with the letter target
            and the original text/label kept in the metadata
        """
        text: str = record['text']
        label: List[str] = record['label']

        # One "<letter>. <category>" entry per option, in ``self.choices`` order.
        choices_list = []
        for letter, category in self.choices.items():
            choices_list.append(f'{letter}. {category}')

        rendered_prompt = prompt(
            question=f'Text to classify: {text}',
            choices=choices_list,
            template=MULTIPLE_ANSWER_TEMPLATE,
        )
        user_message = ChatMessageUser(content=[ContentText(text=rendered_prompt)])

        # Map the gold labels to their letters and normalise the order.
        letter_for = self.categories_to_letters
        gold_letters = [letter_for[name] for name in label if name in letter_for]
        gold_letters.sort()
        target_letters = ''.join(gold_letters)

        return Sample(
            input=[user_message],
            choices=choices_list,
            target=target_letters,
            metadata={'text': text, 'label': label, 'target_letters': target_letters},
        )

    def extract_answer(self, prediction: str, task_state: TaskState) -> str:
        pattern = r'ANSWER:\s*([A-E]+)'
        match = re.search(pattern, prediction)
        if match:
            letters = match.group(1).strip().upper()
            return ''.join(sorted(set(letters)))
        else:
            try:
                answers = parse_answers(prediction)
                return ''.join(sorted(list(answers)))
            except Exception as e:
                logger.warning(f'Could not extract answer from: {prediction}. Error: {e}')
                return ''

    def match_score(
        self, original_prediction: str, filtered_prediction: str, reference: str, task_state: TaskState
    ) -> Score:
        """
        Score one multilabel prediction against its reference.

        Both answers are strings of option letters. Each letter is translated to
        its category through ``self.choices``; letters without a category are
        ignored. Note that when both resulting sets are empty, precision, recall
        and F1 are 0.0 while exact_match is 1.0.

        Args:
            original_prediction: The original model output
            filtered_prediction: The extracted answer (letter format, e.g., "AC")
            reference: The reference answer (letter format, e.g., "AC")
            task_state: The current task state

        Returns:
            Score whose ``value`` holds f1 / precision / recall / exact_match and
            whose ``metadata['category_data']`` holds per-category tp / fp / fn /
            support flags (0 or 1) for later aggregation
        """

        def to_categories(letters: str) -> set:
            # Unknown letters map to '' and are filtered out again.
            names = (self.choices.get(ch, '') for ch in letters)
            return {name for name in names if name}

        predicted = to_categories(filtered_prediction)
        gold = to_categories(reference)

        # Sample-level confusion counts.
        hits = len(predicted & gold)
        spurious = len(predicted - gold)
        missed = len(gold - predicted)

        precision = 0.0
        if hits + spurious > 0:
            precision = hits / (hits + spurious)

        recall = 0.0
        if hits + missed > 0:
            recall = hits / (hits + missed)

        f1 = 0.0
        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)

        exact_match = float(predicted == gold)

        # Per-category indicator flags; summed over samples during aggregation.
        category_data = {
            name: {
                'tp': int(name in predicted and name in gold),
                'fp': int(name in predicted and name not in gold),
                'fn': int(name not in predicted and name in gold),
                'support': int(name in gold),
            }
            for name in self.categories
        }

        score = Score(
            extracted_prediction=filtered_prediction,
            prediction=original_prediction,
        )
        score.value = {'f1': f1, 'precision': precision, 'recall': recall, 'exact_match': exact_match}
        score.metadata = {'category_data': category_data}
        return score

    def aggregate_scores(self, sample_scores: List[SampleScore]) -> List[AggScore]:
        """
        Aggregate per-sample scores into corpus-level multilabel metrics.

        Per-category tp / fp / fn / support counts are read from each sample's
        ``score.metadata['category_data']`` and summed over all samples. Samples
        without that entry (or without metadata) only contribute to exact_match.
        From the totals the method derives:

        - ``f1_weighted``: per-category F1 averaged with weights proportional to
          each category's support (0.0 when the total support is zero)
        - ``f1_micro``: F1 of the tp / fp / fn counts pooled over all categories
        - ``f1_macro``: unweighted mean of the per-category F1 scores, so a
          category that never occurs still contributes 0.0
        - ``exact_match``: mean of the per-sample ``exact_match`` values

        Args:
            sample_scores: List of sample scores

        Returns:
            List of aggregated scores in the order f1_weighted, f1_micro,
            f1_macro, exact_match. The per-category F1 values are attached to the
            metadata of the f1_weighted entry. For empty input all four scores
            are 0.0 with num=0.
        """
        if not sample_scores:
            metric_names = ('f1_weighted', 'f1_micro', 'f1_macro', 'exact_match')
            return [AggScore(metric_name=name, score=0.0, num=0, metadata={}) for name in metric_names]

        count_keys = ('tp', 'fp', 'fn', 'support')
        category_stats = {cat: dict.fromkeys(count_keys, 0) for cat in self.categories}
        total_exact_matches = 0
        num_samples = len(sample_scores)

        # Sum the per-sample indicator flags into per-category totals.
        for ss in sample_scores:
            total_exact_matches += ss.score.value.get('exact_match', 0)

            # Tolerate metadata (or its category_data entry) being None or absent.
            cat_data = (ss.score.metadata or {}).get('category_data') or {}
            for cat, stats in cat_data.items():
                totals = category_stats.get(cat)
                if totals is None:
                    continue  # category unknown to this adapter
                for key in count_keys:
                    totals[key] += stats.get(key, 0)

        # Per-category F1.
        category_f1 = {
            cat: self._f1_from_counts(stats['tp'], stats['fp'], stats['fn'])
            for cat, stats in category_stats.items()
        }

        # Micro average: pool the counts first, then compute a single F1.
        f1_micro = self._f1_from_counts(
            sum(stats['tp'] for stats in category_stats.values()),
            sum(stats['fp'] for stats in category_stats.values()),
            sum(stats['fn'] for stats in category_stats.values()),
        )

        # Macro average: plain mean over all configured categories.
        f1_macro = sum(category_f1.values()) / len(self.categories) if self.categories else 0.0

        # Weighted average: each category weighted by its share of the support.
        total_support = sum(stats['support'] for stats in category_stats.values())
        f1_weighted = 0.0
        if total_support > 0:
            for cat, stats in category_stats.items():
                f1_weighted += category_f1[cat] * (stats['support'] / total_support)

        exact_match = total_exact_matches / num_samples

        return [
            AggScore(
                metric_name='f1_weighted',
                score=f1_weighted,
                num=num_samples,
                metadata={'category_f1': dict(category_f1)},
            ),
            AggScore(metric_name='f1_micro', score=f1_micro, num=num_samples, metadata={}),
            AggScore(metric_name='f1_macro', score=f1_macro, num=num_samples, metadata={}),
            AggScore(metric_name='exact_match', score=exact_match, num=num_samples, metadata={}),
        ]

    @staticmethod
    def _f1_from_counts(tp: int, fp: int, fn: int) -> float:
        """Return the F1 score for the given counts; 0.0 whenever a denominator would be zero."""
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        if precision + recall > 0:
            return 2 * precision * recall / (precision + recall)
        return 0.0
